{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-2.16257178, -1.71055526, -1.6717897 , ...,  1.20444518,\n",
       "       -0.60166689, -0.64335253])"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "#生成数据,alpha描述了两个集合的数据是不均衡的\n",
    "alpha = np.array([0.3, 0.7])\n",
    "mu = np.array([-2, 0.5])\n",
    "sigmod = np.array([0.5, 1])\n",
    "\n",
    "data0 = np.random.normal(mu[0], sigmod[0], int(2000 * alpha[0]))\n",
    "data1 = np.random.normal(mu[1], sigmod[1], int(2000 * alpha[1]))\n",
    "\n",
    "data = np.concatenate((data0, data1))\n",
    "\n",
    "#初始化参数\n",
    "alpha = np.array([0.5, 0.5])\n",
    "mu = np.array([0, 1])\n",
    "sigmod = np.array([1, 1])\n",
    "\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.39778096, 0.24648743, 0.37877552, 0.09648062, 0.10488653,\n",
       "       0.39893111, 0.3985222 , 0.35443724, 0.37572414, 0.36849207])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import math\n",
    "\n",
    "#高斯响应函数,衡量了每个数据和高斯分布吻合的程度,数字越大,越匹配\n",
    "\n",
    "\n",
    "def gauss(data, mu, sigmod):\n",
    "    left = 1 / (math.sqrt(2 * math.pi) * sigmod**2)\n",
    "    right = np.exp(-1 * (data - mu)**2 / (2 * sigmod**2))\n",
    "    return left * right\n",
    "\n",
    "\n",
    "gauss(np.random.normal(5, 1, 10), 5, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1248.66309497147, 751.33690502853, 2000.0)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def E():\n",
    "    gamma = np.empty((2, len(data)))\n",
    "\n",
    "    #得到数据对当前分部的吻合程度,乘以alpha是要考虑数据权重\n",
    "    #但是现在还不知道哪些数据是哪个分布里的,只能全部都乘以alpha\n",
    "    #这样在后面收敛的时候,吻合度大的数据会占据更大的alpha,自然就能收敛出正确的alpha\n",
    "    gamma[0] = alpha[0] * gauss(data, mu[0], sigmod[0])\n",
    "    gamma[1] = alpha[1] * gauss(data, mu[1], sigmod[1])\n",
    "\n",
    "    #转换为概率\n",
    "    _sum = gamma[0] + gamma[1]\n",
    "\n",
    "    gamma[0] = gamma[0] / _sum\n",
    "    gamma[1] = gamma[1] / _sum\n",
    "\n",
    "    return gamma\n",
    "\n",
    "\n",
    "gamma = E()\n",
    "gamma[0].sum(), gamma[1].sum(), gamma.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([-0.8583791 ,  0.76196917]),\n",
       " array([1.56093929, 1.1690965 ]),\n",
       " array([0.62433155, 0.37566845]))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def M(gamma):\n",
    "    #经过计算之后会生成新的mu,sigmod,alpha\n",
    "    mu_new = np.empty(2)\n",
    "    sigmod_new = np.empty(2)\n",
    "    alpha_new = np.empty(2)\n",
    "    \n",
    "    #计算新的mu\n",
    "    #匹配和数据分别相乘再相加,等于是根据匹配度给数据一个权重,相加之后等于带权重的和\n",
    "    #再除以权重的和,等于求上一步带权重和的均值.\n",
    "    mu_new[0] = np.dot(gamma[0], data) / gamma[0].sum()\n",
    "    mu_new[1] = np.dot(gamma[1], data) / gamma[1].sum()\n",
    "\n",
    "    #计算新的sigmoid\n",
    "    #这个是现在的方差\n",
    "    var = (data - mu[0])**2\n",
    "\n",
    "    #这个和上一步差不多,带权方差求均值\n",
    "    var = np.dot(gamma[0], var) / gamma[0].sum()\n",
    "\n",
    "    #因为是方差,这里开方得到标准差\n",
    "    sigmod_new[0] = math.sqrt(var)\n",
    "\n",
    "    var = (data - mu[1])**2\n",
    "    var = np.dot(gamma[1], var) / gamma[1].sum()\n",
    "    sigmod_new[1] = math.sqrt(var)\n",
    "\n",
    "    #计算新的alpha\n",
    "    #其实就是gamma求均值即可\n",
    "    alpha_new[0] = gamma[0].mean()\n",
    "    alpha_new[1] = gamma[1].mean()\n",
    "\n",
    "    return mu_new, sigmod_new, alpha_new\n",
    "\n",
    "\n",
    "M(gamma)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 [0.62433155 0.37566845] [-0.8583791   0.76196917] [1.56093929 1.1690965 ]\n",
      "10 [0.43139131 0.56860869] [-1.62416087  0.79313512] [0.82322768 0.88385842]\n",
      "20 [0.38639765 0.61360235] [-1.79639422  0.72434066] [0.66943875 0.8931643 ]\n",
      "30 [0.37829012 0.62170988] [-1.82469136  0.70868639] [0.64436688 0.89917189]\n",
      "40 [0.37713856 0.62286144] [-1.82865702  0.7064038 ] [0.64085916 0.90010918]\n",
      "50 [0.37698326 0.62301674] [-1.82919066  0.70609481] [0.64038739 0.90023728]\n",
      "60 [0.37696248 0.62303752] [-1.82926205  0.70605344] [0.64032428 0.90025445]\n",
      "70 [0.3769597 0.6230403] [-1.82927159  0.70604791] [0.64031585 0.90025675]\n",
      "80 [0.37695933 0.62304067] [-1.82927287  0.70604717] [0.64031472 0.90025706]\n",
      "90 [0.37695928 0.62304072] [-1.82927304  0.70604707] [0.64031457 0.9002571 ]\n"
     ]
    }
   ],
   "source": [
    "#训练目标\n",
    "#0.3, 0.7    -2, 0.5    0.5, 1\n",
    "\n",
    "#训练\n",
    "for i in range(100):\n",
    "    gamma = E()\n",
    "    mu, sigmod, alpha = M(gamma)\n",
    "    if i % 10 == 0:\n",
    "        print(i, alpha, mu, sigmod)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
